view InteropCodeGen/ComInteropProxyGenerator.cs @ 7:a7650e26195f tip

Added support for generating Com Callable Wrappers
author Ivo Smits
date Fri, 06 May 2011 08:12:43 +0200
parents e640ca67b819
children
line wrap: on
line source
???using System;
using System.Collections.Generic;
using System.Text;
using System.IO;

namespace VBoxSDK {
	/// <summary>
	/// Generates C# interop source for a COM type library: per-interface delegate types,
	/// vtable structures, managed proxies around native interface pointers (<c>X_Proxy</c>)
	/// and COM callable wrappers that expose managed objects to native code (<c>X_CCW</c>).
	/// </summary>
	/// <remarks>
	/// Generated text is written without indentation to <see cref="Output"/>. It refers to
	/// <c>HRESULT</c>, <c>Guid</c>, <c>Marshal</c> and <c>Dictionary</c>, so the file it is written into
	/// must bring those into scope. NOTE(review): <c>HRESULT</c> is assumed to be Int32-compatible,
	/// because generated methods return Int32 values as HRESULT and pass HRESULTs to
	/// <c>Marshal.ThrowExceptionForHR</c>.
	/// </remarks>
	class ComInteropProxyGenerator {
		/// <summary>Destination for all generated source text.</summary>
		public TextWriter Output { get; private set; }
		/// <summary>Not read by this class. NOTE(review): presumably reserved for emitting inherited members.</summary>
		public Boolean IncludeInheritedMembers { get; set; }
		/// <summary>Text prepended to generated delegates, vtable structs and proxy classes; may be null.</summary>
		public String TypeModifiers { get; set; }
		/// <summary>Whether <c>X_Y_Delegate</c> types are emitted for each interface member.</summary>
		public Boolean GenerateDelegates { get; set; }
		/// <summary>Whether <c>X_vtable</c> structs are emitted.</summary>
		public Boolean GenerateVTables { get; set; }
		/// <summary>Whether <c>IComProxy</c> and the <c>X_Proxy</c> classes are emitted.</summary>
		public Boolean GenerateProxies { get; set; }
		/// <summary>Whether <c>IComCallableWrapper</c>, <c>IUnknown_CCW</c> and the <c>X_CCW</c> classes are emitted.</summary>
		public Boolean GenerateCCWs { get; set; }

		/// <summary>Creates a generator writing to <paramref name="output"/> with every kind of output enabled.</summary>
		/// <param name="output">Writer that receives the generated source.</param>
		public ComInteropProxyGenerator(TextWriter output) {
			Output = output;
			GenerateDelegates = true;
			GenerateVTables = true;
			GenerateProxies = true;
			GenerateCCWs = true;
		}

		/// <summary>
		/// Writes the shared support types selected by the Generate* switches, followed by the
		/// generated code for every interface of <paramref name="lib"/> (in dictionary enumeration order).
		/// </summary>
		/// <param name="lib">Library whose interfaces are emitted.</param>
		public void WriteLibrary(LibraryInfo lib) {
			if (GenerateProxies) WriteProxySupport();
			if (GenerateCCWs) WriteCCWSupport();
			foreach (KeyValuePair<String, InterfaceInfo> intf in lib.Interfaces) {
				WriteInterface(intf.Value);
			}
		}

		/// <summary>Emits the <c>IComProxy</c> interface implemented by the root generated proxy class.</summary>
		private void WriteProxySupport() {
			Output.WriteLine("interface IComProxy {");
			Output.WriteLine("IntPtr ComPointer { get; }");
			Output.WriteLine("void AddRef();");
			Output.WriteLine("}");
		}

		/// <summary>
		/// Emits <c>IComCallableWrapper</c> and <c>IUnknown_CCW</c>, the shared IUnknown implementation
		/// used by every generated CCW. Native pointers handed out by CCWs are mapped back to their
		/// managed wrapper through <c>IUnknown_CCW.Instances</c>.
		/// </summary>
		private void WriteCCWSupport() {
			Output.WriteLine("public interface IComCallableWrapper {");
			Output.WriteLine("Guid IID { get; }");
			Output.WriteLine("Int32 ReferenceCount { get; set; }");
			Output.WriteLine("Object Object { get; }");
			Output.WriteLine("IntPtr Pointer { get; }");
			Output.WriteLine("}");

			Output.WriteLine("public unsafe class IUnknown_CCW {");
			Output.WriteLine("public static Dictionary<IntPtr, IComCallableWrapper> Instances = new Dictionary<IntPtr, IComCallableWrapper>();");
			Output.WriteLine("public static readonly Guid IID_IUnknown = new Guid(\"00000000-0000-0000-C000-000000000046\");");
			Output.WriteLine("public static IUnknown_vtable GetFunctionTable() {");
			Output.WriteLine("IUnknown_vtable functions;");
			Output.WriteLine("functions.AddRef = AddRef;");
			Output.WriteLine("functions.QueryInterface = QueryInterface;");
			Output.WriteLine("functions.Release = Release;");
			Output.WriteLine("return functions;");
			Output.WriteLine("}");

			// COM requires an unsupported IID to yield a null pointer and a failure HRESULT
			// (E_NOINTERFACE = 0x80004002). Returning S_FALSE (1), which is a success code, made
			// native callers believe the interface was available. IID_IUnknown must always succeed.
			Output.WriteLine("public static HRESULT QueryInterface(IntPtr pthis, ref Guid priid, out IntPtr p) {");
			Output.WriteLine("IComCallableWrapper cw = Instances[pthis];");
			Output.WriteLine("if (priid != cw.IID && priid != IID_IUnknown) {");
			Output.WriteLine("p = IntPtr.Zero;");
			Output.WriteLine("return unchecked((int)0x80004002);");
			Output.WriteLine("}");
			Output.WriteLine("p = pthis;");
			Output.WriteLine("cw.ReferenceCount++;");
			Output.WriteLine("return 0;");
			Output.WriteLine("}");

			Output.WriteLine("public static HRESULT AddRef(IntPtr pthis) {");
			Output.WriteLine("Instances[pthis].ReferenceCount++;");
			Output.WriteLine("return Instances[pthis].ReferenceCount;");
			Output.WriteLine("}");

			// On the last Release the wrapper is unregistered and both native allocations made by the
			// generated CCW constructor (the vtable block and the object pointer cell) are freed.
			Output.WriteLine("public static HRESULT Release(IntPtr pthis) {");
			Output.WriteLine("IComCallableWrapper cw = Instances[pthis];");
			Output.WriteLine("cw.ReferenceCount--;");
			Output.WriteLine("if (cw.ReferenceCount == 0) {");
			Output.WriteLine("Instances.Remove(pthis);");
			// NOTE(review): this tests the wrapper, not cw.Object; the generated X_CCW classes do not
			// implement IDisposable, so only hand-written wrappers are disposed here.
			Output.WriteLine("if (cw is IDisposable) (cw as IDisposable).Dispose();");
			Output.WriteLine("Marshal.FreeCoTaskMem(*(IntPtr*)pthis);");
			Output.WriteLine("Marshal.FreeCoTaskMem(pthis);");
			Output.WriteLine("return 0;");
			Output.WriteLine("} else {");
			Output.WriteLine("return cw.ReferenceCount;");
			Output.WriteLine("}");
			Output.WriteLine("}");
			Output.WriteLine("}");
		}

		/// <summary>
		/// Writes the delegates, vtable struct, proxy class and CCW class for one interface,
		/// each subject to its Generate* switch. Only the interface's own members are emitted.
		/// The base interface is reached through the embedded base vtable and the base proxy class.
		/// </summary>
		/// <param name="intf">Interface to emit.</param>
		public void WriteInterface(InterfaceInfo intf) {
			if (GenerateDelegates) WriteInterfaceDelegates(intf.Name, intf.Members);
			if (GenerateVTables) WriteInterfaceVTableStructure(intf.Name, intf.Extends, intf.Members);
			if (GenerateProxies) WriteInterfaceProxyClasses(intf);
			if (GenerateCCWs) WriteInterfaceCCWClass(intf);
		}

		/// <summary>
		/// Emits one native-signature delegate type per property getter (<c>out value</c>), property
		/// setter (<c>value</c>) and method. A method takes its parameters plus a trailing <c>out mOut</c>
		/// when it has a return type. Interface-typed values are passed as raw <c>IntPtr</c>.
		/// </summary>
		private void WriteInterfaceDelegates(String Name, IEnumerable<InterfaceMemberInfo> Members) {
			foreach (InterfaceMemberInfo member in Members) {
				if (member is PropertyInfo) {
					PropertyInfo memberi = (PropertyInfo)member;
					if (memberi.Gettable) {
						Output.Write("{2} delegate HRESULT {0}_{1}_get_Delegate(IntPtr pThis, ", Name, memberi.Name, TypeModifiers);
						if (memberi.Type is InterfaceTypeInfo) {
							Output.WriteLine("out IntPtr value);");
						} else {
							WriteTypeComMarshalAttributes(memberi.Type, null);
							Output.WriteLine("out {0} value);", memberi.Type.Name);
						}
					}
					if (memberi.Settable) {
						Output.Write("{2} delegate HRESULT {0}_{1}_set_Delegate(IntPtr pThis, ", Name, memberi.Name, TypeModifiers);
						if (memberi.Type is InterfaceTypeInfo) {
							Output.WriteLine("IntPtr value);");
						} else {
							WriteTypeComMarshalAttributes(memberi.Type, null);
							Output.WriteLine("{0} value);", memberi.Type.Name);
						}
					}
				} else if (member is MethodInfo) {
					MethodInfo memberi = (MethodInfo)member;
					Output.Write("{0} delegate HRESULT {1}_{2}_Delegate(IntPtr mThis", TypeModifiers, Name, memberi.Name);
					foreach (MethodParameterInfo param in memberi.Parameters) {
						Output.Write(", ");
						if (param.Type is InterfaceTypeInfo) {
							WriteDirectionModifier(param);
							Output.Write("IntPtr p{0}", param.Name);
						} else {
							// Marshalling attributes must precede the out/ref modifier.
							WriteTypeComMarshalAttributes(param.Type, null);
							WriteDirectionModifier(param);
							Output.Write("{0} p{1}", param.Type.Name, param.Name);
						}
					}
					if (memberi.ReturnType != null) {
						Output.Write(", ");
						if (memberi.ReturnType is InterfaceTypeInfo) {
							Output.Write("out IntPtr mOut");
						} else {
							WriteTypeComMarshalAttributes(memberi.ReturnType, null);
							Output.Write("out {0} mOut", memberi.ReturnType.Name);
						}
					}
					Output.WriteLine(");");
				}
			}
		}

		/// <summary>
		/// Emits the <c>Name_vtable</c> struct. The base interface's vtable (if any) is embedded as the
		/// first field, named after the base interface. It is followed by one function-pointer-marshalled
		/// delegate field per getter, setter and method, in member order.
		/// </summary>
		private void WriteInterfaceVTableStructure(String Name, InterfaceInfo Extends, IEnumerable<InterfaceMemberInfo> Members) {
			Output.WriteLine("{1} struct {0}_vtable {{", Name, TypeModifiers);
			if (Extends != null) {
				Output.WriteLine("public {0}_vtable {0};", Extends.Name);
			}
			foreach (InterfaceMemberInfo member in Members) {
				if (member is PropertyInfo) {
					PropertyInfo memberi = (PropertyInfo)member;
					if (memberi.Gettable) {
						Output.WriteLine("[MarshalAs(UnmanagedType.FunctionPtr)] public {0}_{1}_get_Delegate get_{1};", Name, memberi.Name);
					}
					if (memberi.Settable) {
						Output.WriteLine("[MarshalAs(UnmanagedType.FunctionPtr)] public {0}_{1}_set_Delegate set_{1};", Name, memberi.Name);
					}
				} else if (member is MethodInfo) {
					MethodInfo memberi = (MethodInfo)member;
					Output.WriteLine("[MarshalAs(UnmanagedType.FunctionPtr)] public {0}_{1}_Delegate {1};", Name, memberi.Name);
				}
			}
			Output.WriteLine("}");
		}

		/// <summary>
		/// Emits managed proxy members (properties and methods) that forward to the native vtable
		/// through <see cref="WriteMethodComCall(String, IEnumerable{MethodParameterInfo}, TypeInfo)"/>.
		/// Properties with neither getter nor setter are skipped.
		/// </summary>
		private void WriteInterfaceMembers(ICollection<InterfaceMemberInfo> members) {
			foreach (InterfaceMemberInfo member in members) {
				if (member is PropertyInfo) {
					PropertyInfo memberi = (PropertyInfo)member;
					if (!memberi.Gettable && !memberi.Settable) continue;
					Output.WriteLine("public {0} {1} {{", memberi.Type.Name, memberi.Name);
					if (memberi.Gettable) {
						Output.WriteLine("get {");
						WriteMethodComCall("get_" + memberi.Name, null, memberi.Type);
						Output.WriteLine("}");
					}
					if (memberi.Settable) {
						Output.WriteLine("set {");
						// An unnamed parameter is emitted as the setter's implicit 'value'.
						WriteMethodComCall("set_" + memberi.Name, new MethodParameterInfo(null, memberi.Type, true, false, false));
						Output.WriteLine("}");
					}
					Output.WriteLine("}");
				} else if (member is MethodInfo) {
					MethodInfo memberi = (MethodInfo)member;
					Output.Write("public {0} {1}(", memberi.ReturnType == null ? "void" : memberi.ReturnType.Name, memberi.Name);
					Boolean first = true;
					foreach (MethodParameterInfo parameter in memberi.Parameters) {
						if (!first) Output.Write(", ");
						first = false;
						WriteDirectionModifier(parameter);
						Output.Write("{0} p{1}", parameter.Type.Name, parameter.Name);
					}
					Output.WriteLine(") {");
					WriteMethodComCall(memberi.Name, memberi.Parameters, memberi.ReturnType);
					Output.WriteLine("}");
				}
			}
		}

		/// <summary>Convenience overload for calls without a return value (used for property setters).</summary>
		private void WriteMethodComCall(String fname, params MethodParameterInfo[] parameters) {
			WriteMethodComCall(fname, (IEnumerable<MethodParameterInfo>)parameters, null);
		}

		/// <summary>
		/// Emits the body of a proxy member: it converts interface arguments to native pointers,
		/// calls <c>functions.fname</c> with <c>Pointer</c> as <c>this</c>, and throws on a failing HRESULT.
		/// It then wraps outgoing interface pointers in proxies and returns the result (if any).
		/// </summary>
		/// <param name="fname">Name of the vtable slot to call.</param>
		/// <param name="parameters">Parameters to forward; may be null.</param>
		/// <param name="returnType">Return type, or null for none.</param>
		private void WriteMethodComCall(String fname, IEnumerable<MethodParameterInfo> parameters, TypeInfo returnType) {
			if (returnType != null) {
				Output.WriteLine("{0} retval;", NativeTypeName(returnType));
			}

			// Each interface-typed argument gets an IntPtr local 'l<name>'.
			if (parameters != null) {
				foreach (MethodParameterInfo parameter in parameters) {
					if (!(parameter.Type is InterfaceTypeInfo)) continue;
					String pname = GetCallerArgumentName(parameter);
					Output.WriteLine("IntPtr l{0};", pname);
					if (!parameter.Input) continue;
					Output.WriteLine("if ({0} == null) {{", pname);
					Output.WriteLine("l{0} = IntPtr.Zero;", pname);
					Output.WriteLine("}} else if ({0} is IComProxy) {{", pname);
					Output.WriteLine("((IComProxy){0}).AddRef();", pname);
					Output.WriteLine("l{0} = ((IComProxy){0}).ComPointer;", pname);
					Output.WriteLine("} else {");
					// NOTE(review): only IFramebuffer is exposed through a generated CCW; other managed
					// implementations go through the runtime's COM interop.
					if (parameter.Type.Name == "IFramebuffer") {
						Output.WriteLine("l{0} = (new {1}_CCW({0})).Pointer;", pname, parameter.Type.Name);
					} else {
						Output.WriteLine("l{0} = Marshal.GetIUnknownForObject({0});", pname);
					}
					Output.WriteLine("}");
				}
			}

			Output.Write("HRESULT hr = functions.{0}(Pointer", fname);
			if (parameters != null) {
				foreach (MethodParameterInfo parameter in parameters) {
					Output.Write(", ");
					WriteDirectionModifier(parameter);
					String pname = GetCallerArgumentName(parameter);
					Output.Write(parameter.Type is InterfaceTypeInfo ? "l" + pname : pname);
				}
			}
			if (returnType != null) {
				Output.Write(", out retval");
			}
			Output.WriteLine(");");
			Output.WriteLine("Marshal.ThrowExceptionForHR(hr);");

			if (parameters != null) {
				foreach (MethodParameterInfo parameter in parameters) {
					if (parameter.Type is InterfaceTypeInfo && parameter.Output) {
						Output.WriteLine("p{0} = lp{0} == IntPtr.Zero ? null : new {1}_Proxy(lp{0});", parameter.Name, parameter.Type.Name);
					}
				}
			}

			if (returnType is InterfaceTypeInfo) {
				Output.WriteLine("return retval == IntPtr.Zero ? null : new {0}_Proxy(retval);", returnType.Name);
			} else if (returnType != null) {
				Output.WriteLine("return retval;");
			}
		}

		/// <summary>
		/// Emits <c>X_Proxy</c>, a managed implementation of interface X that forwards to a native
		/// pointer. The root proxy (no base interface) obtains the pointer via QueryInterface.
		/// Derived proxies chain to their base proxy and re-read the vtable as their own struct type.
		/// </summary>
		private void WriteInterfaceProxyClasses(InterfaceInfo intf) {
			if (intf.Extends != null) {
				Output.WriteLine("{1} unsafe class {0}_Proxy : {2}_Proxy, {0} {{", intf.Name, TypeModifiers, intf.Extends.Name);
				Output.WriteLine("public new static Guid IID = new Guid(\"{0}\");", intf.IID);
			} else {
				Output.WriteLine("{1} unsafe class {0}_Proxy : IComProxy, {0} {{", intf.Name, TypeModifiers);
				Output.WriteLine("public IntPtr Pointer { get; protected set; }");
				Output.WriteLine("IntPtr IComProxy.ComPointer { get { return Pointer; } }");
				Output.WriteLine("public static Guid IID = new Guid(\"{0}\");", intf.IID);
			}
			Output.WriteLine("private {0}_vtable functions;", intf.Name);

			Output.WriteLine("public {0}_Proxy(IntPtr p) : this(p, IID) {{ }}", intf.Name);

			if (intf.Extends != null) {
				Output.WriteLine("public {0}_Proxy(IntPtr p, Guid iid) : base(p, iid) {{", intf.Name);
			} else {
				Output.WriteLine("public {0}_Proxy(IntPtr p, Guid iid) {{", intf.Name);
				Output.WriteLine("IUnknown_vtable ft = (IUnknown_vtable)Marshal.PtrToStructure(*(IntPtr*)p, typeof(IUnknown_vtable));");
				// A successful QueryInterface already AddRefs the returned pointer. That single reference
				// is what the IUnknown_Proxy finalizer releases. The former extra AddRef was never balanced.
				Output.WriteLine("HRESULT hr = ft.QueryInterface(p, ref iid, out p);");
				Output.WriteLine("Marshal.ThrowExceptionForHR(hr);");
				Output.WriteLine("Pointer = p;");
			}
			Output.WriteLine("functions = ({0}_vtable)Marshal.PtrToStructure(*(IntPtr*)Pointer, typeof({0}_vtable));", intf.Name);
			Output.WriteLine("}");

			if (intf.Name == "IUnknown") {
				Output.WriteLine("~{0}_Proxy() {{ functions.Release(Pointer); }}", intf.Name);
			}

			WriteInterfaceMembers(intf.Members);
			Output.WriteLine("}");
		}

		/// <summary>
		/// Emits <c>X_CCW</c>, which exposes a managed X to native code through a vtable of static
		/// callbacks. Interfaces without a base (IUnknown) are served by the shared <c>IUnknown_CCW</c>.
		/// The members of IDispatch and IErrorInfo are not implemented, so their slots stay null.
		/// </summary>
		private void WriteInterfaceCCWClass(InterfaceInfo intf) {
			if (intf.Extends == null) return;
			Boolean writeMembers = intf.Name != "IDispatch" && intf.Name != "IErrorInfo";

			Output.WriteLine("public unsafe class {0}_CCW : IComCallableWrapper {{", intf.Name);
			Output.WriteLine("public IntPtr Pointer { get; protected set; }");
			Output.WriteLine("public Object Object { get; private set; }");
			Output.WriteLine("public Int32 ReferenceCount { get; set; }");
			Output.WriteLine("Guid IComCallableWrapper.IID {{ get {{ return {0}_Proxy.IID; }} }}", intf.Name);
			Output.WriteLine("private {0}_vtable functionTableRoot;", intf.Name);

			WriteCCWFunctionTable(intf, writeMembers);
			WriteCCWConstructor(intf.Name);

			if (writeMembers) {
				foreach (InterfaceMemberInfo member in intf.Members) {
					if (member is PropertyInfo) {
						WriteCCWProperty((PropertyInfo)member);
					} else if (member is MethodInfo) {
						WriteCCWMethod((MethodInfo)member);
					}
				}
			}
			Output.WriteLine("}");
		}

		/// <summary>
		/// Emits the CCW's static <c>GetFunctionTable</c>. It fills the interface's own slots (when
		/// <paramref name="writeMembers"/> is true) and the embedded base vtable from the base CCW.
		/// </summary>
		private void WriteCCWFunctionTable(InterfaceInfo intf, Boolean writeMembers) {
			Output.WriteLine("public static {0}_vtable GetFunctionTable() {{", intf.Name);
			Output.WriteLine("{0}_vtable functions = new {0}_vtable();", intf.Name);
			if (writeMembers) {
				foreach (InterfaceMemberInfo member in intf.Members) {
					if (member is PropertyInfo) {
						PropertyInfo property = (PropertyInfo)member;
						if (property.Gettable) Output.WriteLine("functions.get_{0} = get_{0};", property.Name);
						if (property.Settable) Output.WriteLine("functions.set_{0} = set_{0};", property.Name);
					} else if (member is MethodInfo) {
						Output.WriteLine("functions.{0} = {0};", ((MethodInfo)member).Name);
					}
				}
			}
			Output.WriteLine("functions.{0} = {0}_CCW.GetFunctionTable();", intf.Extends.Name);
			Output.WriteLine("return functions;");
			Output.WriteLine("}");
		}

		/// <summary>
		/// Emits the CCW constructor and the <c>GetInstance</c> lookup. The constructor allocates a
		/// native object (a pointer cell referencing a native vtable copy) and registers it in
		/// <c>IUnknown_CCW.Instances</c>.
		/// </summary>
		private void WriteCCWConstructor(String name) {
			Output.WriteLine("public {0}_CCW({0} p) {{", name);
			Output.WriteLine("Object = p;");
			// StructureToPtr only copies function pointers into native memory. Keeping the struct in a
			// field keeps the delegates reachable while this wrapper is registered in Instances.
			// Otherwise they could be collected while native code still calls them.
			Output.WriteLine("functionTableRoot = GetFunctionTable();");
			Output.WriteLine("Pointer = Marshal.AllocCoTaskMem(IntPtr.Size);");
			Output.WriteLine("*(IntPtr*)Pointer = Marshal.AllocCoTaskMem(Marshal.SizeOf(typeof({0}_vtable)));", name);
			Output.WriteLine("Marshal.StructureToPtr(functionTableRoot, *(IntPtr*)Pointer, false);");
			Output.WriteLine("IUnknown_CCW.Instances[Pointer] = this;");
			Output.WriteLine("}");
			Output.WriteLine("private static {0} GetInstance(IntPtr pthis) {{", name);
			Output.WriteLine("return ({0})IUnknown_CCW.Instances[pthis].Object;", name);
			Output.WriteLine("}");
		}

		/// <summary>Emits the static <c>get_</c>/<c>set_</c> callbacks for a property of a CCW.</summary>
		private void WriteCCWProperty(PropertyInfo property) {
			Boolean isInterface = property.Type is InterfaceTypeInfo;
			String nativeType = NativeTypeName(property.Type);
			if (property.Gettable) {
				Output.WriteLine("private static HRESULT get_{0}(IntPtr pthis, out {1} ret) {{", property.Name, nativeType);
				if (isInterface) {
					Output.WriteLine("{1} iret = GetInstance(pthis).{0};", property.Name, property.Type.Name);
					WriteManagedToComPointer("ret", "iret", property.Type.Name);
				} else {
					Output.WriteLine("ret = GetInstance(pthis).{0};", property.Name);
				}
				Output.WriteLine("return 0;");
				Output.WriteLine("}");
			}
			if (property.Settable) {
				Output.WriteLine("private static HRESULT set_{0}(IntPtr pthis, {1} value) {{", property.Name, nativeType);
				if (isInterface) {
					Output.WriteLine("{0} ivalue;", property.Type.Name);
					WriteNativeToManagedInterface("ivalue", "value", property.Type.Name);
					Output.WriteLine("GetInstance(pthis).{0} = ivalue;", property.Name);
				} else {
					Output.WriteLine("GetInstance(pthis).{0} = value;", property.Name);
				}
				Output.WriteLine("return 0;");
				Output.WriteLine("}");
			}
		}

		/// <summary>
		/// Emits the static callback for a method of a CCW. The callback converts incoming interface
		/// pointers to managed references and invokes the managed implementation. It then converts
		/// outgoing interface references back to owned native pointers.
		/// </summary>
		private void WriteCCWMethod(MethodInfo method) {
			Output.Write("private static HRESULT {0}(IntPtr pthis", method.Name);
			foreach (MethodParameterInfo parameter in method.Parameters) {
				Output.Write(", ");
				WriteDirectionModifier(parameter);
				Output.Write("{0} p{1}", NativeTypeName(parameter.Type), parameter.Name);
			}
			if (method.ReturnType != null) {
				Output.Write(", out {0} retval", NativeTypeName(method.ReturnType));
			}
			Output.WriteLine(") {");

			foreach (MethodParameterInfo parameter in method.Parameters) {
				if (!(parameter.Type is InterfaceTypeInfo)) continue;
				Output.WriteLine("{0} i{1};", parameter.Type.Name, parameter.Name);
				if (parameter.Input) {
					WriteNativeToManagedInterface("i" + parameter.Name, "p" + parameter.Name, parameter.Type.Name);
				}
			}

			if (method.ReturnType is InterfaceTypeInfo) {
				Output.Write("{0} iretval = ", method.ReturnType.Name);
			} else if (method.ReturnType != null) {
				Output.Write("retval = ");
			}
			Output.Write("GetInstance(pthis).{0}(", method.Name);
			Boolean first = true;
			foreach (MethodParameterInfo parameter in method.Parameters) {
				if (!first) Output.Write(", ");
				first = false;
				WriteDirectionModifier(parameter);
				Output.Write("{0}{1}", parameter.Type is InterfaceTypeInfo ? "i" : "p", parameter.Name);
			}
			Output.WriteLine(");");

			foreach (MethodParameterInfo parameter in method.Parameters) {
				if (parameter.Type is InterfaceTypeInfo && (parameter.Output || parameter.Reference)) {
					WriteManagedToComPointer("p" + parameter.Name, "i" + parameter.Name, parameter.Type.Name);
				}
			}
			if (method.ReturnType is InterfaceTypeInfo) {
				WriteManagedToComPointer("retval", "iretval", method.ReturnType.Name);
			}

			Output.WriteLine("return 0;");
			Output.WriteLine("}");
		}

		/// <summary>
		/// Emits a statement that stores in <paramref name="target"/> the managed view of the native
		/// pointer <paramref name="source"/>. The result is null for IntPtr.Zero, the original managed
		/// object when the pointer belongs to one of our CCWs, and a new <c>X_Proxy</c> otherwise.
		/// </summary>
		private void WriteNativeToManagedInterface(String target, String source, String typeName) {
			Output.WriteLine("if ({1} == IntPtr.Zero) {{ {0} = null; }} else if (IUnknown_CCW.Instances.ContainsKey({1})) {{ {0} = ({2})(IUnknown_CCW.Instances[{1}].Object); }} else {{ {0} = new {2}_Proxy({1}); }}", target, source, typeName);
		}

		/// <summary>
		/// Emits a statement that stores in <paramref name="target"/> an owned native pointer for the
		/// managed reference <paramref name="source"/>. COM out-parameters must be AddRef'd: existing
		/// proxies get <c>AddRef()</c>, and new CCWs start with one reference held by the native caller.
		/// </summary>
		private void WriteManagedToComPointer(String target, String source, String typeName) {
			Output.WriteLine("if ({1} == null) {{ {0} = IntPtr.Zero; }} else if ({1} is IComProxy) {{ ((IComProxy){1}).AddRef(); {0} = ((IComProxy){1}).ComPointer; }} else {{", target, source);
			Output.WriteLine("{2}_CCW ccw = new {2}_CCW({1}); ccw.ReferenceCount++; {0} = ccw.Pointer; }}", target, source, typeName);
		}

		/// <summary>
		/// Writes the C# passing modifier for <paramref name="parameter"/>: "out " for output-only
		/// parameters, "ref " for other by-reference parameters, and nothing otherwise.
		/// </summary>
		private void WriteDirectionModifier(MethodParameterInfo parameter) {
			if (parameter.Output && !parameter.Input) {
				Output.Write("out ");
			} else if (parameter.Reference) {
				Output.Write("ref ");
			}
		}

		/// <summary>Type used on the native side of a call: interface types travel as raw IntPtr.</summary>
		private static String NativeTypeName(TypeInfo type) {
			return type is InterfaceTypeInfo ? "IntPtr" : type.Name;
		}

		/// <summary>
		/// Name of the managed variable holding a proxy-member argument: "value" for the unnamed
		/// setter parameter, otherwise "p" followed by the parameter name.
		/// </summary>
		private static String GetCallerArgumentName(MethodParameterInfo parameter) {
			return parameter.Name == null ? "value" : "p" + parameter.Name;
		}

		/// <summary>
		/// Writes a <c>[MarshalAs(UnmanagedType.X)] </c> attribute for interface types (Interface) and
		/// string types (their configured UnmanagedType). Nothing is written for other types or null.
		/// </summary>
		/// <param name="type">Type being marshalled; may be null.</param>
		/// <param name="paramType">Optional attribute target (e.g. "return"), written as "target: ".</param>
		public void WriteTypeComMarshalAttributes(TypeInfo type, String paramType) {
			if (type == null) return;
			String MarshalAs = null;
			if (type is InterfaceTypeInfo) {
				MarshalAs = "Interface";
			} else if (type is StringTypeInfo) {
				MarshalAs = (type as StringTypeInfo).UnmanagedType.ToString();
			}
			if (MarshalAs != null) {
				Output.Write("[");
				if (paramType != null) Output.Write("{0}: ", paramType);
				Output.Write("MarshalAs(UnmanagedType.{0})] ", MarshalAs);
			}
		}
	}
}